from __future__ import annotations

import json
import logging
import reprlib
import time
from dataclasses import dataclass, field
from json.decoder import JSONDecodeError
from typing import Callable, Optional, TypeVar

import requests
import urllib3.exceptions
from typing_extensions import ParamSpec, Protocol, runtime_checkable

from lightly.openapi_generated.swagger_client.models.api_error_code import ApiErrorCode

# Named logger for this module; it uses the fixed "lightly-worker" name instead of __name__.
logger = logging.getLogger(name="lightly-worker")


@runtime_checkable
class APIExceptionLike(Protocol):
    """Structural type for exceptions raised by openapi-generated API clients.

    Only the attributes of the generated APIException class that this module reads
    are declared here. See
    https://github.com/lightly-ai/lightly/blob/3642a12510895c4e917bc05801be613aab86fde2/lightly/openapi_generated/swagger_client/exceptions.py#L103-L115

    A Protocol is used instead of the concrete exception class because there might
    be several such exception classes in use.
    """

    # HTTP status code of the failed response, if any.
    status: Optional[int]
    # Raw response body, if any.
    body: Optional[str]


# HTTP status codes that are considered transient. RetryOnApiConfig copies this
# list into a set as the default for `retry_http_error_codes`, which is used both
# for API exceptions and for requests HTTPErrors.
RETRY_HTTP_ERROR_CODES = [
    400,  # Bad Request. Jeremy: This error is typically caused by the client but could also be caused by the server.
    408,  # Request Timeout (e.g. Google storage returns this)
    409,  # Conflict
    429,  # Too Many Requests
    500,  # Internal Server Error
    502,  # Bad Gateway
    503,  # Service Unavailable
    504,  # Gateway Timeout
]

# The api client generated by swagger uses urllib3 under the hood. It does not
# specify a retry strategy for urllib3 and also only catches certain errors
# raised by urllib3 but not all of them! Therefore api calls can result in
# urllib3 errors which we should retry on. Urllib3 without retries catches,
# potentially converts, and finally reraises all errors listed here:
# https://github.com/urllib3/urllib3/blob/972e9f02cd219892701fa0a1129dda7f81dac535/src/urllib3/connectionpool.py#L764-L773
# Because some errors are converted, we do not have to catch all errors in the
# list but only the ones that can be reraised by urllib3.
RETRY_API_URLLIB3_ERROR = (
    urllib3.exceptions.TimeoutError,
    urllib3.exceptions.ProtocolError,  # urllib3 converts HTTPException and OSError to ProtocolError
    urllib3.exceptions.SSLError,  # urllib3 converts BaseSSLError and CertificateError to SSLError
    urllib3.exceptions.ProxyError,
)

# Requests uses urllib3 under the hood which means that sometimes urllib3 errors
# can be raised from requests code. Here we list all the urllib3 errors that
# should be retried when calling requests. The retried errors are the same as
# the ones caught by requests internally: https://github.com/psf/requests/blob/da9996fe4dc63356e9467d0a5e10df3d89a8528e/requests/models.py#L815-L825
# Note that requests only catches these errors when the iter_content
# function on the response object is called. But we usually use
# response.raw directly without calling iter_content and have to catch
# errors ourselves.
RETRY_REQUESTS_URLLIB3_ERROR = (
    urllib3.exceptions.DecodeError,
    urllib3.exceptions.ReadTimeoutError,
    urllib3.exceptions.ProtocolError,
    urllib3.exceptions.SSLError,
)

# Requests errors that we don't want to retry on, because retrying cannot fix them
# (malformed URLs/headers/JSON, or requests' own retries already exhausted). See
# https://github.com/psf/requests/blob/main/requests/exceptions.py for possible errors.
NOT_RETRY_REQUESTS_ERROR = (
    requests.exceptions.URLRequired,
    requests.exceptions.MissingSchema,
    requests.exceptions.InvalidSchema,
    requests.exceptions.InvalidURL,
    requests.exceptions.InvalidHeader,
    requests.exceptions.RetryError,
    requests.exceptions.JSONDecodeError,
    requests.exceptions.InvalidJSONError,
)


class MaxRetryError(Exception):
    """Raised when a call keeps failing and no retries are left."""


class DatasetNotFoundError(Exception):
    """Raised when a dataset can no longer be found, e.g. because it was deleted."""


@dataclass
class RetryConfig:
    """Settings shared by all retry strategies.

    Attributes:
        max_retries:
            Number of retries after the first failed attempt. The base Retry class
            attempts a call at most ``max_retries + 1`` times.
        backoff_factor:
            Scale of the exponential backoff. For retry index ``n`` (starting at 0)
            the base Retry class sleeps ``backoff_factor * 2 ** (n - 1)`` seconds,
            and 0 seconds for ``n == 0``, capped at ``backoff_max``. This mirrors
            urllib3.util.Retry. Subclasses may adjust the schedule.
        backoff_max:
            Upper bound in seconds for a single backoff sleep.
    """

    max_retries: int = 5  # retries, not counting the first attempt
    backoff_factor: float = 1.0  # seconds, scaled exponentially per retry
    backoff_max: float = 120.0  # seconds


@dataclass
class RetryOnApiConfig(RetryConfig):
    """Retry settings for calls made through the generated API client.

    Attributes:
        backoff_min_on_429:
            Lower bound in seconds for the sleep before retrying a call that failed
            with HTTP 429 (Too Many Requests).
        retry_http_error_codes:
            HTTP status codes that trigger a retry. Defaults to a fresh set built
            from RETRY_HTTP_ERROR_CODES.
        retry_api_error_codes:
            API error codes (the "code" entry of the error body) that trigger a
            retry. Empty by default.
        retry_api_urllib3_error:
            urllib3 exception types that trigger a retry of an API call.
    """

    backoff_min_on_429: float = 10.0
    # A new set per instance so that mutating one config does not affect others.
    retry_http_error_codes: set[int] = field(
        default_factory=lambda: {*RETRY_HTTP_ERROR_CODES}
    )
    retry_api_error_codes: set[ApiErrorCode] = field(default_factory=set)
    retry_api_urllib3_error: tuple[type[Exception], ...] = field(
        default_factory=lambda: RETRY_API_URLLIB3_ERROR
    )


@dataclass
class RetryOnRequestsConfig(RetryOnApiConfig):
    """Retry settings for calls made with the requests library.

    Extends the API settings because such calls can also surface API errors.

    Attributes:
        retry_requests_urllib3_error:
            urllib3 exception types that trigger a retry of a requests call.
        not_retry_requests_error:
            requests exception types that are never retried.
    """

    # Defaults are the module-level tuples; tuples are immutable, so sharing is safe.
    retry_requests_urllib3_error: tuple[type[Exception], ...] = field(
        default_factory=lambda: RETRY_REQUESTS_URLLIB3_ERROR
    )
    not_retry_requests_error: tuple[type[requests.RequestException], ...] = field(
        default_factory=lambda: NOT_RETRY_REQUESTS_ERROR
    )


# Type variables for the retry wrappers: ParamType captures the parameters of the
# wrapped callable (forwarded through *args/**kwargs) and ReturnType its return
# type, so that e.g. retry(fn, *args, **kwargs) type-checks like fn(*args, **kwargs).
ParamType = ParamSpec("ParamType")
ReturnType = TypeVar("ReturnType")


class Retry:
    """Base class to retry function calls when they fail with an exception.

    Subclasses must override should_retry(). They may also override
    calculate_backoff() and _wrap_exception().

    Attributes:
        config: Configuration for retry behavior, including max retries and
            backoff settings.
    """

    def __init__(self, config: RetryConfig):
        self.config = config

    def __call__(
        self,
        fn: Callable[ParamType, ReturnType],
        *args: ParamType.args,
        **kwargs: ParamType.kwargs,
    ) -> ReturnType:
        """Calls ``fn(*args, **kwargs)`` and retries it on retryable exceptions.

        ``fn`` is attempted at most ``config.max_retries + 1`` times. Before each
        retry the call sleeps for ``calculate_backoff(retries, ex)`` seconds.

        Returns:
            The return value of the first successful call to ``fn``.

        Raises:
            MaxRetryError: If the final allowed attempt fails with an exception for
                which should_retry() returns True. The last exception is chained as
                ``__cause__``.
            BaseException: Any exception for which should_retry() returns False is
                raised immediately, on any attempt, after passing through
                _wrap_exception(). If it was replaced by a wrapper, the original is
                chained as ``__cause__``.
        """
        retries = 0
        while True:
            try:
                return fn(*args, **kwargs)
            except BaseException as ex:
                # Decide retryability first: a non-retryable error (e.g. a
                # KeyboardInterrupt) must surface as itself even on the last attempt
                # instead of being hidden inside a MaxRetryError.
                if not self.should_retry(ex):
                    wrapped = self._wrap_exception(ex)
                    if wrapped is ex:
                        raise
                    raise wrapped from ex
                if retries >= self.config.max_retries:
                    raise MaxRetryError(
                        self._max_retry_message(
                            fn, args, kwargs, attempts=retries + 1, error=ex
                        )
                    ) from ex
                # TODO(Guarin, 02/23): Disabled logging because it can be called from
                # a different process than the main process, resulting in a potential
                # deadlock. Enable logging again once we switch to using spawn or
                # forkserver to start processes. See LIG-2486.
                # logger.debug(
                #     f"Retrying: Retry {retries + 1} of {self.config.max_retries} after {repr(ex)}"
                # )
                time.sleep(self.calculate_backoff(retries, ex))
            retries += 1

    def _max_retry_message(
        self,
        fn: Callable[..., object],
        args: tuple,
        kwargs: dict,
        *,
        attempts: int,
        error: BaseException,
    ) -> str:
        """Builds the MaxRetryError message for a call that exhausted its retries."""
        # Not every callable has __name__ (e.g. functools.partial objects); fall back
        # to a truncated repr so that building the message cannot raise by itself.
        fn_name = getattr(fn, "__name__", None) or reprlib.repr(fn)
        return (
            f"Calling '{fn_name}' failed after {attempts} attempt(s). "
            f"Args: {reprlib.repr(args)}; kwargs: {reprlib.repr(kwargs)}. "
            f"Last error: {self.format_error(self._wrap_exception(error))}"
        )

    def format_error(self, exception: BaseException) -> str:
        """Writes errors in a readable format.

        Args:
            exception: Exception raised.

        Returns:
            Formatted error message: the exception class name followed by its
            message, or "None" when the message is empty.
        """
        error_details = str(exception) or "None"
        return f"{exception.__class__.__name__}. Details: {error_details}"

    def should_retry(self, ex: BaseException) -> bool:
        """Returns True if the call that raised ``ex`` should be retried.

        If False is returned then the original exception (possibly wrapped by
        _wrap_exception()) is raised. Must be implemented by subclasses.
        """
        raise NotImplementedError()

    def calculate_backoff(self, retries: int, ex: BaseException) -> float:
        """Returns the time in seconds to sleep before the next try.

        0 for the first retry, then ``backoff_factor * 2 ** (retries - 1)``,
        clamped to ``[0, backoff_max]``.
        """
        if retries <= 0:
            return 0.0
        backoff: float = self.config.backoff_factor * (2 ** (retries - 1))
        return max(0.0, min(backoff, self.config.backoff_max))

    def _wrap_exception(self, error: BaseException) -> BaseException:
        """Hook to replace an exception with a more meaningful one; identity here."""
        return error


class RetryOnApiError(Retry):
    """Retries API calls that fail with transient errors.

    A call is retried when the exception is one of the configured urllib3 error
    types, or when it is an API exception (see APIExceptionLike) whose HTTP status
    or API error code is configured as retryable.

    Attributes:
        config: Retry configuration, including max retries, backoff settings and
            the error types/codes that are retried.
    """

    config: RetryOnApiConfig

    def __init__(
        self,
        config: RetryOnApiConfig,
    ):
        super().__init__(config=config)

    def should_retry(self, ex: BaseException) -> bool:
        """Returns True for retryable urllib3 errors and retryable API errors."""
        urllib3_failure = isinstance(ex, self.config.retry_api_urllib3_error)
        if not isinstance(ex, APIExceptionLike):
            return urllib3_failure
        status_retryable = ex.status in self.config.retry_http_error_codes
        code_retryable = (
            _get_error_code_from_api_exception(ex=ex)
            in self.config.retry_api_error_codes
        )
        return urllib3_failure or status_retryable or code_retryable

    def calculate_backoff(self, retries: int, ex: BaseException) -> float:
        """Returns the base backoff shifted by one retry, at least
        ``backoff_min_on_429`` seconds for HTTP 429 responses."""
        # Shift by one so that even the first retry waits, relieving the api.
        delay = super().calculate_backoff(retries + 1, ex)
        too_many_requests = isinstance(ex, APIExceptionLike) and ex.status == 429
        if too_many_requests:
            return max(delay, self.config.backoff_min_on_429)
        return delay

    def _wrap_exception(self, error: BaseException) -> BaseException:
        """Replaces a DATASET_UNKNOWN API error with a DatasetNotFoundError.

        Args:
            error: The original exception.

        Returns:
            A DatasetNotFoundError for API errors with code DATASET_UNKNOWN,
            otherwise ``error`` itself.
        """
        if not isinstance(error, APIExceptionLike):
            return error
        if _get_error_code_from_api_exception(error) == ApiErrorCode.DATASET_UNKNOWN:
            return DatasetNotFoundError(
                "Dataset has suddenly disappeared and cannot be found. Did you delete it?"
            )
        return error


class RetryOnRequestsError(RetryOnApiError):
    """Retries calls made with requests on requests, urllib3 and api errors.

    Api errors are retried because redirected read urls are handled by our api
    while the original request in the worker is made with requests. urllib3
    errors are retried because requests uses urllib3 under the hood.
    """

    config: RetryOnRequestsConfig

    def __init__(
        self,
        config: RetryOnRequestsConfig,
    ):
        super().__init__(config=config)

    def should_retry(self, ex: BaseException) -> bool:
        """Returns True if the failed requests call that raised ``ex`` should be retried."""
        # Retryable api errors first, then retryable urllib3 errors.
        if super().should_retry(ex) or isinstance(
            ex, self.config.retry_requests_urllib3_error
        ):
            return True
        # Requests errors that a retry cannot fix.
        if isinstance(ex, self.config.not_retry_requests_error):
            return False
        if isinstance(ex, requests.exceptions.HTTPError):
            # Compare against None explicitly: a Response is falsy for 4xx/5xx codes.
            # A missing response is retried; otherwise use the same HTTP codes as for
            # the api. Requests raises HTTPError for any status code in [400, 599], see:
            # https://github.com/psf/requests/blob/6e5b15d542a4e85945fd72066bb6cecbc3a82191/requests/models.py#L1010-L1018
            response = ex.response
            return (
                response is None
                or response.status_code in self.config.retry_http_error_codes
            )
        # Any other requests error is retried; everything else is not.
        return isinstance(ex, requests.exceptions.RequestException)


# Default retry for all api calls.
retry = RetryOnApiError(config=RetryOnApiConfig())


# Callable helper for fully retrying requests calls with a timeout. Call it as
# retry_on_requests_error(fn, *args, **kwargs).
#
# It can be used in conjunction with requests Sessions that already have a retry
# strategy. The retry mechanism in requests does not fully retry the request but
# only retries the last action. This can be a problem when downloading files
# using streaming, as requests will only retry the last read() call to read the
# next chunk of the stream. That retry fails, for example, if the connection was
# interrupted. This helper avoids the problem by calling fn again from scratch.
# Provided fn performs the whole request and download, a read() error therefore
# restarts the full download.
retry_on_requests_error = RetryOnRequestsError(config=RetryOnRequestsConfig())


def no_retry(
    fn: Callable[ParamType, ReturnType],
    *args: ParamType.args,
    **kwargs: ParamType.kwargs,
) -> ReturnType:
    """Calls ``fn(*args, **kwargs)`` exactly once, without any retry logic.

    Mirrors the signature of the Retry instances in this module, so it can be
    passed wherever such a retry callable is expected. Exceptions propagate as-is.
    """
    result = fn(*args, **kwargs)
    return result


def _get_error_code_from_api_exception(
    ex: APIExceptionLike,
) -> Optional[str]:
    """Returns api error code from ApiException.

    This is the "code" part from the failed api response. For example, a request failed
    with:

        HTTP response body: {
            "code": "MALFORMED_REQUEST",
            "error": "sampleId must match the following: \"/^[a-f0-9]{24}$/\"",
            "requestId": "a805582f5069db7696b9b66604873b3c"
        }

    will return "MALFORMED_REQUEST".
    """
    # TODO(Guarin, 03/23): Add debug log messages if ex.body has unexpected format once
    # we use process with spawn method.

    # Api error code exists only if returned body is json string.
    if not isinstance(ex.body, str):
        return None
    try:
        body = json.loads(ex.body)
    except JSONDecodeError:
        # Invalid json string.
        return None
    # We expect the api to return a json dict with a "code" entry.
    if isinstance(body, dict):
        return body.get("code")
    else:
        return None
